{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9e0c1743-8775-4de2-a599-78b050551489",
   "metadata": {},
   "source": [
    "# Language Agent Tree Search\n",
    "\n",
    "[Language Agent Tree Search](https://andyz245.github.io/LanguageAgentTreeSearch/) (LATS), by Zhou, et. al, is a general LLM agent search algorithm that combines reflection/evaluation and search (specifically monte-carlo trees search) to get achieve better overall task performance compared to similar techniques like ReACT, Reflexion, or Tree of Thoughts.\n",
    "\n",
    "![LATS diagram](./img/lats.png)\n",
    "\n",
    "It has four main steps:\n",
    "\n",
    "1. Select: pick the best next actions based on the aggreate rewards from step (2). Either respond (if a solution is found or the max search depth is reached) or continue searching.\n",
    "2. Expand and simulate: select the \"best\" 5 potential actions to take and execute them in parallel.\n",
    "3. Reflect + Evaluate: observe the outcomes of these actions and score the decisions based on reflection (and possibly external feedback)\n",
    "4. Backpropagate: update the scores of the root trajectories based on the outcomes."
   ]
  },
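  {
   "cell_type": "markdown",
   "id": "lats-sketch-backpropagate",
   "metadata": {},
   "source": [
    "As a rough illustration of step 4, the sketch below folds each new reward into a running average on every node from the new candidate up to the root. `ToyNode` and the reward values are hypothetical stand-ins, not the `Node` class defined later in this notebook.\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    def __init__(self, parent=None):\n",
    "        self.parent = parent\n",
    "        self.value = 0.0  # running average of rewards seen in this subtree\n",
    "        self.visits = 0\n",
    "\n",
    "\n",
    "def backpropagate(node, reward):\n",
    "    # Walk from the scored node up to the root, updating each running average.\n",
    "    while node is not None:\n",
    "        node.visits += 1\n",
    "        node.value += (reward - node.value) / node.visits\n",
    "        node = node.parent\n",
    "\n",
    "\n",
    "root = ToyNode()\n",
    "backpropagate(root, 0.5)  # assumed score for the initial response\n",
    "child_a = ToyNode(parent=root)\n",
    "backpropagate(child_a, 0.9)  # assumed score for a strong candidate\n",
    "child_b = ToyNode(parent=root)\n",
    "backpropagate(child_b, 0.1)  # assumed score for a weak candidate\n",
    "```\n",
    "\n",
    "After these three updates the root has been visited three times, and its value is the mean of 0.5, 0.9, and 0.1."
   ]
  },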
  {
   "cell_type": "markdown",
   "id": "db28668b-5491-4c93-a961-bd339f09202c",
   "metadata": {},
   "source": [
    "## 0. Prerequisites\n",
    "\n",
    "Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).\n",
    "\n",
    "We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc9159b-cc8c-426d-9670-3e8ada06723f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -U --quiet  langchain langgraph langchain_openai\n",
    "# %pip install -U --quiet tavily-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a177ecc9-0c96-460f-9b39-9c1ce54754f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_if_undefined(var: str) -> None:\n",
    "    if os.environ.get(var):\n",
    "        return\n",
    "    os.environ[var] = getpass.getpass(var)\n",
    "\n",
    "\n",
    "# Optional: Configure tracing to visualize and debug the agent\n",
    "_set_if_undefined(\"LANGCHAIN_API_KEY\")\n",
    "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "os.environ[\"LANGCHAIN_PROJECT\"] = \"LATS\"\n",
    "\n",
    "_set_if_undefined(\"OPENAI_API_KEY\")\n",
    "_set_if_undefined(\"TAVILY_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f857eacb-af4a-47d1-b45f-da74941125c2",
   "metadata": {},
   "source": [
    "## Graph State\n",
    "\n",
    "LATS is based on a  (greedy) Monte-Carlo tree search. For each search steps, it picks the node with the highest \"upper confidence bound\", which is a metric that balances exploitation (highest average reward) and exploration (lowest visits). Starting from that node, it generates N (5 in this case) new candidate actions to take, and adds them to the tree. It stops searching either when it has generated a valid solution OR when it has reached the maximum number of rollouts (search tree depth).\n",
    "\n",
    "![Tree Diagram](./img/tree.png)\n",
    "\n",
    "Our LangGraph state will be composed of two items:\n",
    "1. The root of the search tree\n",
    "2. The user input"
   ]
  },
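  {
   "cell_type": "markdown",
   "id": "lats-sketch-uct",
   "metadata": {},
   "source": [
    "To make the exploitation/exploration trade-off concrete, here is a minimal sketch of an upper-confidence-bound score. It assumes each node stores a running average reward in [0, 1] and a visit count; `uct_score` and the sample statistics are hypothetical and only illustrate the idea.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def uct_score(value, visits, parent_visits, exploration_weight=1.0):\n",
    "    # value: assumed running average reward for the node, in [0, 1].\n",
    "    if visits == 0:\n",
    "        # No statistics yet, so fall back to the raw value.\n",
    "        return value\n",
    "    # The bonus shrinks as a node is visited more often relative to its parent.\n",
    "    exploration_term = math.sqrt(math.log(parent_visits) / visits)\n",
    "    return value + exploration_weight * exploration_term\n",
    "\n",
    "\n",
    "# Made-up statistics for three sibling candidates: (average reward, visits).\n",
    "parent_visits = 10\n",
    "candidates = {\n",
    "    'well_explored': (0.8, 6),\n",
    "    'promising': (0.7, 3),\n",
    "    'rarely_tried': (0.5, 1),\n",
    "}\n",
    "scores = {\n",
    "    name: uct_score(value, visits, parent_visits)\n",
    "    for name, (value, visits) in candidates.items()\n",
    "}\n",
    "best = max(scores, key=scores.get)\n",
    "```\n",
    "\n",
    "With these made-up numbers, the rarely tried candidate scores highest because its exploration bonus outweighs its lower average reward. Lowering `exploration_weight` shifts the choice back toward the highest-reward candidate."
   ]
  },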
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "54c6f319-3966-4f66-aa7b-50e249189111",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import math\n",
    "from typing import List, Optional\n",
    "\n",
    "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n",
    "\n",
    "\n",
    "class Node:\n",
    "    def __init__(\n",
    "        self,\n",
    "        messages: List[BaseMessage],\n",
    "        reflection: Reflection,\n",
    "        parent: Optional[Node] = None,\n",
    "    ):\n",
    "        self.messages = messages\n",
    "        self.parent = parent\n",
    "        self.children = []\n",
    "        self.value = 0\n",
    "        self.visits = 0\n",
    "        self.reflection = reflection\n",
    "        self.depth = parent.depth + 1 if parent is not None else 1\n",
    "        self._is_solved = reflection.found_solution if reflection else False\n",
    "        if self._is_solved:\n",
    "            self._mark_tree_as_solved()\n",
    "        self.backpropagate(reflection.normalized_score)\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return (\n",
    "            f\"<Node value={self.value}, visits={self.visits},\"\n",
    "            f\" solution={self.messages} reflection={self.reflection}/>\"\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def is_solved(self):\n",
    "        \"\"\"If any solutions exist, we can end the search.\"\"\"\n",
    "        return self._is_solved\n",
    "\n",
    "    @property\n",
    "    def is_terminal(self):\n",
    "        return not self.children\n",
    "\n",
    "    @property\n",
    "    def best_child(self):\n",
    "        \"\"\"Select the child with the highest UCT to search next.\"\"\"\n",
    "        if not self.children:\n",
    "            return None\n",
    "        all_nodes = self._get_all_children()\n",
    "        return max(all_nodes, key=lambda child: child.upper_confidence_bound())\n",
    "\n",
    "    @property\n",
    "    def best_child_score(self):\n",
    "        \"\"\"Return the child with the highest value.\"\"\"\n",
    "        if not self.children:\n",
    "            return None\n",
    "        return max(self.children, key=lambda child: int(child.is_solved) * child.value)\n",
    "\n",
    "    @property\n",
    "    def height(self) -> int:\n",
    "        \"\"\"Check for how far we've rolled out the tree.\"\"\"\n",
    "        if self.children:\n",
    "            return 1 + max([child.height for child in self.children])\n",
    "        return 1\n",
    "\n",
    "    def upper_confidence_bound(self, exploration_weight=1.0):\n",
    "        \"\"\"Return the UCT score. This helps balance exploration vs. exploitation of a branch.\"\"\"\n",
    "        if self.parent is None:\n",
    "            raise ValueError(\"Cannot obtain UCT from root node\")\n",
    "        if self.visits == 0:\n",
    "            return self.value\n",
    "        # Encourages exploitation of high-value trajectories\n",
    "        average_reward = self.value / self.visits\n",
    "        # Encourages exploration of less-visited trajectories\n",
    "        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)\n",
    "        return average_reward + exploration_weight * exploration_term\n",
    "\n",
    "    def backpropagate(self, reward: float):\n",
    "        \"\"\"Update the score of this node and its parents.\"\"\"\n",
    "        node = self\n",
    "        while node:\n",
    "            node.visits += 1\n",
    "            node.value = (node.value * (node.visits - 1) + reward) / node.visits\n",
    "            node = node.parent\n",
    "\n",
    "    def get_messages(self, include_reflections: bool = True):\n",
    "        if include_reflections:\n",
    "            return self.messages + [self.reflection.as_message()]\n",
    "        return self.messages\n",
    "\n",
    "    def get_trajectory(self, include_reflections: bool = True) -> List[BaseMessage]:\n",
    "        \"\"\"Get messages representing this search branch.\"\"\"\n",
    "        messages = []\n",
    "        node = self\n",
    "        while node:\n",
    "            messages.extend(\n",
    "                node.get_messages(include_reflections=include_reflections)[::-1]\n",
    "            )\n",
    "            node = node.parent\n",
    "        # Reverse the final back-tracked trajectory to return in the correct order\n",
    "        return messages[::-1]  # root solution, reflection, child 1, ...\n",
    "\n",
    "    def _get_all_children(self):\n",
    "        all_nodes = []\n",
    "        nodes = deque()\n",
    "        nodes.append(self)\n",
    "        while nodes:\n",
    "            node = nodes.popleft()\n",
    "            all_nodes.extend(node.children)\n",
    "            for n in node.children:\n",
    "                nodes.append(n)\n",
    "        return all_nodes\n",
    "\n",
    "    def get_best_solution(self):\n",
    "        \"\"\"Return the best solution from within the current sub-tree.\"\"\"\n",
    "        all_nodes = [self] + self._get_all_children()\n",
    "        best_node = max(\n",
    "            all_nodes,\n",
    "            # We filter out all non-terminal, non-solution trajectories\n",
    "            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,\n",
    "        )\n",
    "        return best_node\n",
    "\n",
    "    def _mark_tree_as_solved(self):\n",
    "        parent = self.parent\n",
    "        while parent:\n",
    "            parent._is_solved = True\n",
    "            parent = parent.parent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdd3111f-b860-471f-8784-1d5e3783910d",
   "metadata": {},
   "source": [
    "#### The graph state itself\n",
    "\n",
    "The main component is the tree, represented by the root node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e10c94ba-9daa-4899-97ce-4f28428c2c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing_extensions import TypedDict\n",
    "\n",
    "\n",
    "class TreeState(TypedDict):\n",
    "    # The full tree\n",
    "    root: Node\n",
    "    # The original input\n",
    "    input: str"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e8ddf25-d040-4e1f-87bd-5837ff105845",
   "metadata": {},
   "source": [
    "## Define Language Agent\n",
    "\n",
    "Our agent will have three primary LLM-powered processes:\n",
    "1. Reflect: score the action based on the tool response.\n",
    "2. Initial response: to create the root node and start the search.\n",
    "3. Expand: generate 5 candidate \"next steps\" from the best spot in the current tree\n",
    "\n",
    "For more \"Grounded\" tool applications (such as code synthesis), you could integrate code execution into the reflection/reward step. This type of external feedback is very useful (though adds complexity to an already complicated example notebook)."
   ]
  },
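  {
   "cell_type": "markdown",
   "id": "lats-sketch-blended-reward",
   "metadata": {},
   "source": [
    "One possible way to fold external feedback into the reward is sketched below. Everything here is an assumption for illustration: `blended_reward`, the test pass rate as the external signal, and the 50/50 default `weight` are hypothetical choices, not part of LATS or of this notebook.\n",
    "\n",
    "```python\n",
    "def blended_reward(reflection_score, tests_passed, tests_total, weight=0.5):\n",
    "    # reflection_score: self-reflection score on the 0-10 scale used in this notebook.\n",
    "    # tests_passed / tests_total: hypothetical external signal, such as unit tests.\n",
    "    # weight: assumed share of the final reward given to the external signal.\n",
    "    if not 0 <= reflection_score <= 10:\n",
    "        raise ValueError('reflection_score must be between 0 and 10')\n",
    "    normalized = reflection_score / 10.0\n",
    "    if tests_total <= 0:\n",
    "        # No external feedback available, so rely on self-reflection alone.\n",
    "        return normalized\n",
    "    pass_rate = tests_passed / tests_total\n",
    "    return (1 - weight) * normalized + weight * pass_rate\n",
    "\n",
    "\n",
    "reward_no_tests = blended_reward(8, 0, 0)\n",
    "reward_half_passing = blended_reward(8, 2, 4)\n",
    "```\n",
    "\n",
    "A candidate that the model rates 8/10 but that passes only half of its tests ends up with a lower reward than the self-reflection score alone would suggest."
   ]
  },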
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "48738896-42ac-47eb-b482-0d4d4dd86c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d460856-e26d-4430-910e-0aac58563612",
   "metadata": {},
   "source": [
    "#### Tools\n",
    "\n",
    "For our example, we will give the language agent a search engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "55c2aff3-f454-43da-8f45-1a3d46523cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper\n",
    "\n",
    "from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation\n",
    "\n",
    "search = TavilySearchAPIWrapper()\n",
    "tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)\n",
    "tools = [tavily_tool]\n",
    "tool_executor = ToolExecutor(tools=tools)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c611f1e-74b4-4157-997c-face8ad409a4",
   "metadata": {},
   "source": [
    "### Reflection\n",
    "\n",
    "The reflection chain will score agent outputs based on the decision and the tool responses.\n",
    "We will call this within the other two nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ddfd1750-c265-4b29-b505-83b1c5e2d30e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import create_structured_output_runnable\n",
    "from langchain.output_parsers.openai_tools import (\n",
    "    JsonOutputToolsParser,\n",
    "    PydanticToolsParser,\n",
    ")\n",
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "from langchain_core.runnables import chain as as_runnable\n",
    "\n",
    "\n",
    "class Reflection(BaseModel):\n",
    "    reflections: str = Field(\n",
    "        description=\"The critique and reflections on the sufficiency, superfluency,\"\n",
    "        \" and general quality of the response\"\n",
    "    )\n",
    "    score: int = Field(\n",
    "        description=\"Score from 0-10 on the quality of the candidate response.\",\n",
    "        gte=0,\n",
    "        lte=10,\n",
    "    )\n",
    "    found_solution: bool = Field(\n",
    "        description=\"Whether the response has fully solved the question or task.\"\n",
    "    )\n",
    "\n",
    "    def as_message(self):\n",
    "        return HumanMessage(\n",
    "            content=f\"Reasoning: {self.reflections}\\nScore: {self.score}\"\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def normalized_score(self) -> float:\n",
    "        return self.score / 10.0\n",
    "\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"Reflect and grade the assistant response to the user question below.\",\n",
    "        ),\n",
    "        (\"user\", \"{input}\"),\n",
    "        MessagesPlaceholder(variable_name=\"candidate\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "reflection_llm_chain = (\n",
    "    prompt\n",
    "    | llm.bind_tools(tools=[Reflection], tool_choice=\"Reflection\").with_config(\n",
    "        run_name=\"Reflection\"\n",
    "    )\n",
    "    | PydanticToolsParser(tools=[Reflection])\n",
    ")\n",
    "\n",
    "\n",
    "@as_runnable\n",
    "def reflection_chain(inputs) -> Reflection:\n",
    "    tool_choices = reflection_llm_chain.invoke(inputs)\n",
    "    reflection = tool_choices[0]\n",
    "    if not isinstance(inputs[\"candidate\"][-1], AIMessage):\n",
    "        reflection.found_solution = False\n",
    "    return reflection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e47dfb2-4ab3-4a31-b117-f07786b357cb",
   "metadata": {},
   "source": [
    "### Initial Response\n",
    "\n",
    "We start with a single root node, generated by this first step. It responds to the user input either with a tool invocation or a response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "72fc5363-f0f3-4362-8499-14eb583bd75b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "from langchain_core.prompt_values import ChatPromptValue\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "\n",
    "prompt_template = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are an AI assistant.\",\n",
    "        ),\n",
    "        (\"user\", \"{input}\"),\n",
    "        MessagesPlaceholder(variable_name=\"messages\", optional=True),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(\n",
    "    run_name=\"GenerateInitialCandidate\"\n",
    ")\n",
    "\n",
    "\n",
    "parser = JsonOutputToolsParser(return_id=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7207f913-a6db-4ef9-a98d-ecb8612b23d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_APBQsd15wnSNPhyghCvFrNC8', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "initial_response = initial_answer_chain.invoke(\n",
    "    {\"input\": \"Write a research report on lithium pollution.\"}\n",
    ")\n",
    "initial_response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a7d34a6-cee0-4321-989a-963ca4b2caeb",
   "metadata": {},
   "source": [
    "#### Starting Node\n",
    "\n",
    "We will package up the candidate generation and reflection in a single node of our graph. This is represented by the following function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5b6b173c-78f5-4ae1-80b3-28c80e68f5c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Define the node we will add to the graph\n",
    "def generate_initial_response(state: TreeState) -> dict:\n",
    "    \"\"\"Generate the initial candidate response.\"\"\"\n",
    "    res = initial_answer_chain.invoke({\"input\": state[\"input\"]})\n",
    "    parsed = parser.invoke(res)\n",
    "    tool_responses = tool_executor.batch(\n",
    "        [ToolInvocation(tool=r[\"type\"], tool_input=r[\"args\"]) for r in parsed]\n",
    "    )\n",
    "    output_messages = [res] + [\n",
    "        ToolMessage(content=json.dumps(resp), tool_call_id=tool_call[\"id\"])\n",
    "        for resp, tool_call in zip(tool_responses, parsed)\n",
    "    ]\n",
    "    reflection = reflection_chain.invoke(\n",
    "        {\"input\": state[\"input\"], \"candidate\": output_messages}\n",
    "    )\n",
    "    root = Node(output_messages, reflection=reflection)\n",
    "    return {\n",
    "        **state,\n",
    "        \"root\": root,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34452e88-e33a-474c-9623-075d1f434dda",
   "metadata": {},
   "source": [
    "### Candidate Generation\n",
    "\n",
    "The following code prompts the same LLM to generate N additional candidates to check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "550bff9a-86aa-43ad-ad98-506e97c122d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This generates N candidate values\n",
    "# for a single input to sample actions from the environment\n",
    "\n",
    "\n",
    "def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):\n",
    "    n = config[\"configurable\"].get(\"N\", 5)\n",
    "    bound_kwargs = llm.bind_tools(tools=tools).kwargs\n",
    "    chat_result = llm.generate(\n",
    "        [messages.to_messages()],\n",
    "        n=n,\n",
    "        callbacks=config[\"callbacks\"],\n",
    "        run_name=\"GenerateCandidates\",\n",
    "        **bound_kwargs\n",
    "    )\n",
    "    return [gen.message for gen in chat_result.generations[0]]\n",
    "\n",
    "\n",
    "expansion_chain = prompt_template | generate_candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e368e61f-8150-4fd6-b3fd-208d1f0ddc9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_5DMq9O6BIden7lLraFH0NuYZ', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}),\n",
       " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_5DMq9O6BIden7lLraFH0NuYZ', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}),\n",
       " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_5DMq9O6BIden7lLraFH0NuYZ', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}),\n",
       " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_5DMq9O6BIden7lLraFH0NuYZ', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}),\n",
       " AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_5DMq9O6BIden7lLraFH0NuYZ', 'function': {'arguments': '{\"query\":\"lithium pollution research report\"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]})]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res = expansion_chain.invoke({\"input\": \"Write a research report on lithium pollution.\"})\n",
    "res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88ecf775-29ed-4ebd-8297-d1aa3cda3f9b",
   "metadata": {},
   "source": [
    "#### Candidate generation node\n",
    "\n",
    "We will package the candidate generation and reflection steps in the following \"expand\" node.\n",
    "We do all the operations as a batch process to speed up execution."
   ]
  },
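  {
   "cell_type": "markdown",
   "id": "lats-sketch-batch-regroup",
   "metadata": {},
   "source": [
    "The batching pattern used in the expand step can be sketched in isolation: flatten every candidate's tool calls into one list, run them as a single batch, then regroup the responses by candidate index. The parsed tool calls and the stand-in responses below are made up for illustration; no real tool is invoked.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Hypothetical parsed tool calls, one list per candidate; candidate 1 made no calls.\n",
    "parsed = [\n",
    "    [{'id': 'a1', 'args': {'query': 'x'}}],\n",
    "    [],\n",
    "    [{'id': 'c1', 'args': {'query': 'y'}}, {'id': 'c2', 'args': {'query': 'z'}}],\n",
    "]\n",
    "\n",
    "# Flatten to (candidate_index, tool_call) pairs so all calls can run in one batch.\n",
    "flattened = [(i, call) for i, calls in enumerate(parsed) for call in calls]\n",
    "\n",
    "# Stand-in for the batched tool execution.\n",
    "responses = ['result for ' + call['args']['query'] for _, call in flattened]\n",
    "\n",
    "# Regroup each response under the candidate that requested it.\n",
    "grouped = defaultdict(list)\n",
    "for (i, call), response in zip(flattened, responses):\n",
    "    grouped[i].append((call['id'], response))\n",
    "\n",
    "# Candidates without tool calls get an empty list.\n",
    "per_candidate = [grouped[i] for i in range(len(parsed))]\n",
    "```\n",
    "\n",
    "Keeping the candidate index next to each tool call is what lets the responses be matched back to the right candidate after the batch returns."
   ]
  },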
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d32af859-53e8-46be-8182-7d522be31f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict, deque\n",
    "\n",
    "\n",
    "def expand(state: TreeState, config: RunnableConfig) -> dict:\n",
    "    \"\"\"Starting from the \"best\" node in the tree, generate N candidates for the next step.\"\"\"\n",
    "    root = state[\"root\"]\n",
    "    best_candidate: Node = root.best_child if root.children else root\n",
    "    messages = best_candidate.get_trajectory()\n",
    "    # Generate N candidates from the single child candidate\n",
    "    new_candidates = expansion_chain.invoke(\n",
    "        {\"input\": state[\"input\"], \"messages\": messages}, config\n",
    "    )\n",
    "    parsed = parser.batch(new_candidates)\n",
    "    flattened = [\n",
    "        (i, tool_call)\n",
    "        for i, tool_calls in enumerate(parsed)\n",
    "        for tool_call in tool_calls\n",
    "    ]\n",
    "    tool_responses = tool_executor.batch(\n",
    "        [\n",
    "            ToolInvocation(tool=tool_call[\"type\"], tool_input=tool_call[\"args\"])\n",
    "            for _, tool_call in flattened\n",
    "        ]\n",
    "    )\n",
    "    collected_responses = defaultdict(list)\n",
    "    for (i, tool_call), resp in zip(flattened, tool_responses):\n",
    "        collected_responses[i].append(\n",
    "            ToolMessage(content=json.dumps(resp), tool_call_id=tool_call[\"id\"])\n",
    "        )\n",
    "    output_messages = []\n",
    "    for i, candidate in enumerate(new_candidates):\n",
    "        output_messages.append([candidate] + collected_responses[i])\n",
    "\n",
    "    # Reflect on each candidate\n",
    "    # For tasks with external validation, you'd add that here.\n",
    "    reflections = reflection_chain.batch(\n",
    "        [{\"input\": state[\"input\"], \"candidate\": msges} for msges in output_messages],\n",
    "        config,\n",
    "    )\n",
    "    # Grow tree\n",
    "    child_nodes = [\n",
    "        Node(cand, parent=best_candidate, reflection=reflection)\n",
    "        for cand, reflection in zip(output_messages, reflections)\n",
    "    ]\n",
    "    best_candidate.children.extend(child_nodes)\n",
    "    # We have already extended the tree directly, so we just return the state\n",
    "    return state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84bad5da-645d-4c6a-83dd-8c852f21f622",
   "metadata": {},
   "source": [
    "## Create Graph\n",
    "\n",
    "With those two nodes defined, we are ready to define the graph. After each agent step, we have the option of finishing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8aec0f20-f978-4df0-8900-e3a1f0544f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END, StateGraph\n",
    "\n",
    "\n",
    "def should_loop(state: TreeState):\n",
    "    \"\"\"Determine whether to continue the tree search.\"\"\"\n",
    "    root = state[\"root\"]\n",
    "    if root.is_solved:\n",
    "        return END\n",
    "    if root.height > 5:\n",
    "        return END\n",
    "    return \"expand\"\n",
    "\n",
    "\n",
    "builder = StateGraph(TreeState)\n",
    "builder.add_node(\"start\", generate_initial_response)\n",
    "builder.add_node(\"expand\", expand)\n",
    "builder.set_entry_point(\"start\")\n",
    "\n",
    "\n",
    "builder.add_conditional_edges(\n",
    "    \"start\",\n",
    "    # Either expand/rollout or finish\n",
    "    should_loop,\n",
    ")\n",
    "builder.add_conditional_edges(\n",
    "    \"expand\",\n",
    "    # Either continue to rollout or finish\n",
    "    should_loop,\n",
    ")\n",
    "\n",
    "graph = builder.compile()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1383d69c-1d90-43f5-987e-c7fc4c3a24f8",
   "metadata": {},
   "source": [
    "## Invoke"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "92392fb3-8431-4649-9e78-2cc160e96ec1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start\n",
      "rolled out:  1\n",
      "---\n",
      "expand\n",
      "rolled out:  2\n",
      "---\n",
      "expand\n",
      "rolled out:  3\n",
      "---\n",
      "__end__\n",
      "rolled out:  3\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "question = \"Generate a table with the average size and weight, as well as the oldest recorded instance for each of the top 5 most common birds.\"\n",
    "for step in graph.stream({\"input\": question}):\n",
    "    step_name, step_state = next(iter(step.items()))\n",
    "    print(step_name)\n",
    "    print(\"rolled out: \", step_state[\"root\"].height)\n",
    "    print(\"---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "37a9e785-9909-4b56-b9be-da484e3711e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The search results have provided detailed information on the average size and weight, as well as the oldest recorded instance for each of the top 5 most common birds: Northern Cardinal, Dark-eyed Junco, Mourning Dove, Downy Woodpecker, and House Finch. Now, I will compile this information into a table format for easy reference. Let's create the table with the average size and weight, as well as the oldest recorded instance for each of these birds.\n",
      "Here is the table with the average size and weight, as well as the oldest recorded instance for each of the top 5 most common birds:\n",
      "\n",
      "| Bird Species        | Average Size       | Average Weight   | Oldest Recorded Instance |\n",
      "|---------------------|--------------------|------------------|--------------------------|\n",
      "| Northern Cardinal   | 21.5 cm (male), 21.25 cm (female) | 42-48 g | 15 years and 9 months |\n",
      "| Dark-eyed Junco     | 14-16 cm            | 18-30 g          | At least 11 years, 4 months old |\n",
      "| Mourning Dove       | 22.5-36 cm          | 96-170 g         | 19 years                |\n",
      "| Downy Woodpecker    | 14-18 cm            | 20-33 g          | At least 11 years        |\n",
      "| House Finch         | 13-14 cm            | 16-27 g          | 8-11 years              |\n",
      "\n",
      "This table summarizes the average size and weight, as well as the oldest recorded instance for each of the top 5 most common birds.\n"
     ]
    }
   ],
   "source": [
    "solution_node = step[\"__end__\"][\"root\"].get_best_solution()\n",
    "best_trajectory = solution_node.get_trajectory(include_reflections=False)\n",
    "print(best_trajectory[-1].content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "1e084037-42e7-4f8e-962d-aaa3f04ab54c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start\n",
      "rolled out:  1\n",
      "---\n",
      "expand\n",
      "rolled out:  2\n",
      "---\n",
      "expand\n",
      "rolled out:  3\n",
      "---\n",
      "__end__\n",
      "rolled out:  3\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "question = \"Write out magnus carlson series of moves in his game against Alireza Firouzja and propose an alternate strategy\"\n",
    "for step in graph.stream({\"input\": question}):\n",
    "    step_name, step_state = next(iter(step.items()))\n",
    "    print(step_name)\n",
    "    print(\"rolled out: \", step_state[\"root\"].height)\n",
    "    print(\"---\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d403c1c8-b26b-4d79-87b1-d2d16c1a7673",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In the game between Magnus Carlsen and Alireza Firouzja, Magnus Carlsen started with the move C4, and Alireza countered with H6. To propose an alternate strategy for Magnus Carlsen, focusing on positional play, creating strong pawn structures, and leveraging his endgame skills could be highly effective. Magnus could aim to control the center, develop his pieces harmoniously, and look for opportunities to gradually improve his position. By maintaining a solid pawn structure and maneuvering his pieces strategically, Magnus could aim to outmaneuver his opponent in the later stages of the game, utilizing his renowned endgame skills to secure a favorable outcome.\n"
     ]
    }
   ],
   "source": [
    "solution_node = step[\"__end__\"][\"root\"].get_best_solution()\n",
    "best_trajectory = solution_node.get_trajectory(include_reflections=False)\n",
    "print(best_trajectory[-1].content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1b5140d-f51e-4032-8bc8-d7153252e3bf",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Congrats on implementing LATS! This is a technique that can be reasonably ast and effective at solving complex reasoning tasks. A few notes that you probably observed above:\n",
    "1. While effective , the tree rollout can take additional compute time. If you wanted to include this in a production app, you'd either want to ensure that intermediate steps are streamed (so the user sees the thinking process/has access to intermediate results) or use it for fine-tuning data to improve the single-shot accuracy and avoid long rollouts.\n",
    "2. The candidate selection process is only as good as the reward you generate. Here we are using self-reflection exclusively, but if you have an external source of feedback (such as code test execution), that should be incorporated in the locations mentioned above."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6130dff9-4753-4556-a39e-330ac65ba9c6",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
